{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-08T04:50:32.013886Z",
     "start_time": "2020-05-08T04:50:32.010359Z"
    }
   },
   "outputs": [],
   "source": [
    "# -*- coding: utf-8 -*-"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1、导入相关的库\n",
    "- Transformer里面是关于Transformer模型的函数\n",
    "- util里面是相关的数据读取文件\n",
    "- train内是相关的训练和测试函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-08T04:50:32.449156Z",
     "start_time": "2020-05-08T04:50:32.026571Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from Transformer import *\n",
    "from util import *\n",
    "from train import *\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1,2,3,4\"\n",
    "device = torch.device('cuda')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2、设置相关的参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-08T04:50:32.458145Z",
     "start_time": "2020-05-08T04:50:32.451725Z"
    }
   },
   "outputs": [],
   "source": [
    "embedding_size = 32  # token的维度\n",
    "num_layers = 2       # 编码器和解码器的层数，这里两者层数相同，也可以不同\n",
    "dropout = 0.05       # 所有层的droprate都相同，也可以不同\n",
    "batch_size = 64      # 批次\n",
    "num_steps = 10       # 预测步长\n",
    "factor = 1           # 学习率因子\n",
    "warmup = 2000        # 学习率上升步长\n",
    "lr, num_epochs, ctx = 0.005, 500, device  # 学习率；周期；设备\n",
    "num_hiddens, num_heads = 64, 4            # 隐层单元的数目——表示FFN中间层的输出维度；attention的数目"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3、导入文件\n",
    "文件为fra.txt文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-08T04:50:38.060788Z",
     "start_time": "2020-05-08T04:50:32.461294Z"
    }
   },
   "outputs": [],
   "source": [
    "src_vocab, tgt_vocab, train_iter = load_data_nmt(batch_size, num_steps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4、加载模型\n",
    "- TransformerEncoder为编码器模型\n",
    "- TransformerDecoder为解码器模型\n",
    "- transformer为编码器和解码器构成的最终模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-08T04:50:38.083069Z",
     "start_time": "2020-05-08T04:50:38.064776Z"
    }
   },
   "outputs": [],
   "source": [
    "encoder = TransformerEncoder(vocab_size=len(src_vocab), \n",
    "                             embedding_size=embedding_size, \n",
    "                             n_layers=num_layers, \n",
    "                             hidden_size=num_hiddens, \n",
    "                             num_heads=num_heads, \n",
    "                             dropout=dropout, )\n",
    "decoder = TransformerDecoder(vocab_size=len(src_vocab), \n",
    "                             embedding_size=embedding_size, \n",
    "                             n_layers=num_layers, \n",
    "                             hidden_size=num_hiddens, \n",
    "                             num_heads=num_heads, \n",
    "                             dropout=dropout, )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-08T04:50:38.095197Z",
     "start_time": "2020-05-08T04:50:38.085535Z"
    }
   },
   "outputs": [],
   "source": [
    "class transformer(nn.Module):\n",
    "    def __init__(self, enc_net, dec_net):\n",
    "        super(transformer, self).__init__()\n",
    "        self.enc_net = enc_net   # TransformerEncoder的对象   \n",
    "        self.dec_net = dec_net   # TransformerDecoder的对象\n",
    "    \n",
    "    def forward(self, enc_X, dec_X, valid_length=None, max_seq_len=None):\n",
    "        \"\"\"\n",
    "        enc_X: 编码器的输入\n",
    "        dec_X: 解码器的输入\n",
    "        valid_length: 编码器的输入对应的valid_length,主要用于编码器attention的masksoftmax中，\n",
    "                      并且还用于解码器的第二个attention的masksoftmax中\n",
    "        max_seq_len:  位置编码时调整sin和cos周期大小的，默认大小为enc_X的第一个维度seq_len\n",
    "        \"\"\"\n",
    "        \n",
    "        # 1、通过编码器得到编码器最后一层的输出enc_output\n",
    "        enc_output = self.enc_net(enc_X, valid_length, max_seq_len)\n",
    "        # 2、state为解码器的初始状态，state包含两个元素，分别为[enc_output, valid_length]\n",
    "        state = self.dec_net.init_state(enc_output, valid_length)\n",
    "        # 3、通过解码器得到编码器最后一层到线性层的输出output，这里的output不是解码器最后一层的输出，而是\n",
    "        #    最后一层再连接线性层的输出\n",
    "        output = self.dec_net(dec_X, state)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-08T04:50:38.101624Z",
     "start_time": "2020-05-08T04:50:38.097736Z"
    }
   },
   "outputs": [],
   "source": [
    "model = transformer(encoder, decoder)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5、训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-08T04:55:48.207886Z",
     "start_time": "2020-05-08T04:50:38.104140Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch   50,loss 0.096, time 29.3 sec\n",
      "epoch  100,loss 0.049, time 30.8 sec\n",
      "epoch  150,loss 0.042, time 30.0 sec\n",
      "epoch  200,loss 0.036, time 30.9 sec\n",
      "epoch  250,loss 0.035, time 31.7 sec\n",
      "epoch  300,loss 0.033, time 30.1 sec\n",
      "epoch  350,loss 0.032, time 31.6 sec\n",
      "epoch  400,loss 0.031, time 31.9 sec\n",
      "epoch  450,loss 0.031, time 30.1 sec\n",
      "epoch  500,loss 0.031, time 30.7 sec\n"
     ]
    }
   ],
   "source": [
    "model.train()\n",
    "train(model, train_iter, lr, factor, warmup, num_epochs, ctx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6、测试模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-08T04:55:48.353943Z",
     "start_time": "2020-05-08T04:55:48.212492Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Go . => va !\n",
      "Wow ! =>  !\n",
      "I'm OK . => ça .\n",
      "I won ! => je l'ai !\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "for sentence in ['Go .', 'Wow !', \"I'm OK .\", 'I won !']:\n",
    "    print(sentence + ' => ' + translate(model, sentence, src_vocab, tgt_vocab, num_steps, ctx))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:gpu-py37]",
   "language": "python",
   "name": "conda-env-gpu-py37-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
